/*
 * SPDX-License-Identifier: Apache-2.0
 */

//===- ConvertONNXToStableHlo.cpp - ONNX dialects to StableHlo lowering -===//
//
// Copyright 2022
//
// =============================================================================
//
// This file implements the lowering of frontend operations to a combination of
// StableHlo IR and standard operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"

#include "src/Conversion/ONNXToStableHlo/ONNXToStableHloCommon.hpp"

using namespace mlir;

namespace onnx_mlir {

/// Appends every ONNX-to-StableHlo rewrite pattern to `patterns`.
///
/// Each populateLoweringONNX*OpToStableHloPattern helper contributes the
/// patterns for one ONNX op (or op family); the calls are grouped below by
/// category (math, neural network, tensor). The set of helpers called here is
/// what determines which ONNX ops this lowering can handle.
/// NOTE(review): the helpers are declared outside this file (presumably in
/// ONNXToStableHloCommon.hpp) — confirm before relying on their contents.
///
/// \param patterns Pattern set that receives the lowering patterns.
/// \param ctx      MLIR context forwarded unchanged to every helper.
void populateONNXToStableHloConversionPattern(
    RewritePatternSet &patterns, MLIRContext *ctx) {
  // Math
  populateLoweringONNXClipOpToStableHloPattern(patterns, ctx);
  populateLoweringONNXElementwiseOpToStableHloPattern(patterns, ctx);
  populateLoweringONNXGemmOpToStableHloPattern(patterns, ctx);
  populateLoweringONNXMatMulOpToStableHloPattern(patterns, ctx);
  populateLoweringONNXReductionOpToStableHloPattern(patterns, ctx);
  // Neural network
  populateLoweringONNXConvOpToStableHloPattern(patterns, ctx);
  populateLoweringONNXConvTransposeOpToStableHloPattern(patterns, ctx);
  populateLoweringONNXNormalizationOpToStableHloPattern(patterns, ctx);
  populateLoweringONNXPoolingOpToStableHloPattern(patterns, ctx);
  // Tensor
  populateLoweringONNXArgMaxOpToStableHloPattern(patterns, ctx);
  populateLoweringONNXConcatOpToStableHloPattern(patterns, ctx);
  populateLoweringONNXConstantOpToStableHloPattern(patterns, ctx);
  populateLoweringONNXExpandOpToStableHloPattern(patterns, ctx);
  populateLoweringONNXFlattenOpToStableHloPattern(patterns, ctx);
  populateLoweringONNXGatherOpToStableHloPattern(patterns, ctx);
  populateLoweringONNXGatherElementsOpToStableHloPattern(patterns, ctx);
  populateLoweringONNXIdentityOpToStableHloPattern(patterns, ctx);
  populateLoweringONNXPadOpToStableHloPattern(patterns, ctx);
  populateLoweringONNXReshapeOpToStableHloPattern(patterns, ctx);
  populateLoweringONNXShapeOpToStableHloPattern(patterns, ctx);
  populateLoweringONNXSliceOpToStableHloPattern(patterns, ctx);
  populateLoweringONNXSplitOpToStableHloPattern(patterns, ctx);
  populateLoweringONNXSqueezeOpToStableHloPattern(patterns, ctx);
  populateLoweringONNXTileOpToStableHloPattern(patterns, ctx);
  populateLoweringONNXTransposeOpToStableHloPattern(patterns, ctx);
  populateLoweringONNXUnsqueezeOpToStableHloPattern(patterns, ctx);
}

//===----------------------------------------------------------------------===//
// Frontend to StableHlo Dialect lowering pass
//===----------------------------------------------------------------------===//

struct FrontendToStableHloLoweringPass
    : public PassWrapper<FrontendToStableHloLoweringPass,
          OperationPass<ModuleOp>> {

  /// Command-line argument that selects this pass.
  StringRef getArgument() const override { return "convert-onnx-to-stablehlo"; }

  /// One-line summary of the pass, shown in pass help output.
  StringRef getDescription() const override {
    return "Lower frontend ops to StableHlo dialect.";
  }

  // The pass declares no options, so both constructors only need to
  // initialize the PassWrapper base. The copy constructor is kept explicit
  // because PassWrapper::clonePass() copy-constructs the pass; the source
  // argument is intentionally unnamed since there is no state to copy.
  FrontendToStableHloLoweringPass() = default;
  FrontendToStableHloLoweringPass(const FrontendToStableHloLoweringPass &)
      : PassWrapper<FrontendToStableHloLoweringPass,
            OperationPass<ModuleOp>>() {}

  /// Declares every dialect whose ops the lowering may produce so they are
  /// loaded into the context before the pass runs. This mirrors the legal
  /// dialects of the conversion target in runOnOperation() (e.g. affine is
  /// needed for the affine maps generated by IndexExpression); building an op
  /// from a dialect that is not loaded fails at runtime.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<mlir::stablehlo::StablehloDialect, func::FuncDialect,
        arith::ArithDialect, shape::ShapeDialect, mlir::affine::AffineDialect,
        tensor::TensorDialect>();
  }

  void runOnOperation() final;
};

void FrontendToStableHloLoweringPass::runOnOperation() {
  ModuleOp module = getOperation();
  // The first thing to define is the conversion target. This will define the
  // final target for this lowering.
  ConversionTarget target(getContext());

  // We define the specific operations, or dialects, that are legal targets for
  // this lowering.
  // Added affine as some affine maps are generated by IndexExpression. It could
  // be disabled and/or replaced by shape max/min.
  target.addLegalDialect<stablehlo::StablehloDialect, func::FuncDialect,
      arith::ArithDialect, shape::ShapeDialect, mlir::affine::AffineDialect,
      tensor::TensorDialect>();
  // Needed to support unsigned int computations. To be removed if we use a
  // scheme that does not rely on the UnrealizedConversionCastOp.
  target.addLegalOp<::mlir::UnrealizedConversionCastOp>();

  // Now that the conversion target has been defined, we just need to provide
  // the set of patterns that will lower the frontend operations.
  RewritePatternSet patterns(&getContext());

  // Define patterns.
  populateONNXToStableHloConversionPattern(patterns, &getContext());

  // add illegal op
  target.addIllegalOp<ONNXSoftmaxOp>();

  // With the target and rewrite patterns defined, we can now attempt the
  // conversion. The conversion will signal failure if any of our `illegal`
  // operations were not converted successfully.
  if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
    signalPassFailure();
  }
}

/// Creates a new instance of the ONNX-to-StableHlo lowering pass
/// (registered as "convert-onnx-to-stablehlo"), owned by the caller.
std::unique_ptr<Pass> createLowerToStableHloPass() {
  std::unique_ptr<Pass> loweringPass =
      std::make_unique<FrontendToStableHloLoweringPass>();
  return loweringPass;
}

} // namespace onnx_mlir
